{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "gpuType": "V100",
      "authorship_tag": "ABX9TyNFolHGhUkgZ1a1kS3P0+qh",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/DanielWarfield1/MLWritingAndResearch/blob/main/LoRA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# LoRA\n",
        "Based on a combination of this guide:\n",
        "https://huggingface.co/docs/peft/task_guides/semantic_segmentation_lora\n",
        "and these:\n",
        "\n",
        "https://www.youtube.com/watch?v=iYr1xZn26R8\n",
        "\n",
        "https://github.com/huggingface/peft/issues/493\n",
        "\n",
        "\n",
        "Recommended runtime: V100 with high RAM, or an A100 with high RAM if you use a larger BLOOM model.\n",
        "\n",
        "Note: I generally prioritize ease of comparing a model with its fine-tuned counterpart over inference time.\n",
        "\n",
        "---"
      ],
      "metadata": {
        "id": "Z9s6u5X9z_uj"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Downloading Dependencies\n",
        "- **bitsandbytes:** for representing models using smaller datatypes, saving on memory.\n",
        "- **datasets:** for downloading datasets\n",
        "- **accelerate:** required by `transformers` to place the model on the available devices (needed for `device_map='auto'`)\n",
        "- **loralib:** LoRA implementation\n",
        "- **peft:** a general \"parameter efficient fine tuning\" module, our interface for LoRA\n",
        "- **transformers:** for downloading and using pre-trained transformers from huggingface."
      ],
      "metadata": {
        "id": "01zIxuZX0mhA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -q bitsandbytes datasets accelerate loralib\n",
        "!pip install -q git+https://github.com/huggingface/peft.git git+https://github.com/huggingface/transformers.git"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "e7aYz5D70CCg",
        "outputId": "d115803a-8ae2-44f0-fde5-04c2364e3d11"
      },
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "  Installing build dependencies ... done\n",
            "  Getting requirements to build wheel ... done\n",
            "  Preparing metadata (pyproject.toml) ... done\n",
            "  Installing build dependencies ... done\n",
            "  Getting requirements to build wheel ... done\n",
            "  Preparing metadata (pyproject.toml) ... done\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Loading Pre-Trained Model\n",
        "In this notebook we're using BLOOM, a decoder-only causal language model. It is a permissively licensed language model trained on a variety of data.\n",
        "\n",
        "We'll be using the 560m parameter version to save on GPU memory, but if you use an A100 instance you should be able to run the 3b parameter version. While not thoroughly tested, all code should work for any flavor of BLOOM."
      ],
      "metadata": {
        "id": "xUGgwzqK124J"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Importing dependencies and downloading pre-trained bloom model\n",
        "\"\"\"\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import bitsandbytes as bnb\n",
        "from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM\n",
        "\n",
        "#loading model\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    # \"bigscience/bloom-3b\",\n",
        "    # \"bigscience/bloom-1b1\",\n",
        "    \"bigscience/bloom-560m\",\n",
        "    torch_dtype=torch.float16,\n",
        "    device_map='auto',\n",
        ")\n",
        "\n",
        "#loading tokenizer for this model (which turns text into an input for the model)\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"bigscience/tokenizer\")"
      ],
      "metadata": {
        "id": "34LkV4O514v-"
      },
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Setting up LoRA\n",
        "- **r:** the rank of the A and B matrices.\n",
        "- **lora_alpha:** this is a pretty controversial parameter. A lot of people have a lot of ideas about it. You can consider it a scaling factor: the LoRA update is multiplied by `lora_alpha / r`. As far as I understand, setting it equal to `r` is a reasonable default.\n",
        "- **target_modules:** the portions of the model we want to optimize with LoRA. BLOOM has modules named `query_key_value`, which are the ones we want to optimize.\n",
        "- **lora_dropout:** dropout is a technique which randomly hides inputs to keep the model from overfitting (a form of regularization). This is the probability of an input being hidden.\n",
        "- **bias:** neural network layers typically have two kinds of parameters, \"weights\" and \"biases\". We're only training weights in this example.\n",
        "- **task_type:** not super necessary, used in the superclass `PeftConfig`. Setting to `CAUSAL_LM` because the specific language model we're using is \"causal\"."
      ],
      "metadata": {
        "id": "5bFZ2oKI2fsI"
      }
    },
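    {
      "cell_type": "markdown",
      "source": [
        "To make `r` and `lora_alpha` more concrete, here is a minimal NumPy sketch of the LoRA update for a single linear layer. It is not PEFT's implementation: the toy sizes, the random values, and the `lora_forward` helper are made up for illustration, and dropout on the adapter input is left out.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "d_in, d_out = 64, 192     # toy sizes, far smaller than a real BLOOM layer\n",
        "r, lora_alpha = 8, 8      # same values as the config used in this notebook\n",
        "\n",
        "W = rng.normal(size=(d_out, d_in))       # frozen pre-trained weight\n",
        "A = rng.normal(size=(r, d_in)) * 0.01    # trainable, small random init (illustrative)\n",
        "B = np.zeros((d_out, r))                 # trainable, zero init\n",
        "\n",
        "def lora_forward(x, W, A, B, r, lora_alpha):\n",
        "    # base output plus the low-rank update, scaled by lora_alpha / r\n",
        "    scaling = lora_alpha / r\n",
        "    return x @ W.T + scaling * (x @ A.T @ B.T)\n",
        "\n",
        "x = rng.normal(size=(2, d_in))\n",
        "\n",
        "# with B at zero the adapter leaves the base layer's output unchanged\n",
        "assert np.allclose(lora_forward(x, W, A, B, r, lora_alpha), x @ W.T)\n",
        "\n",
        "# the adapter trains r * (d_in + d_out) values instead of d_in * d_out\n",
        "print(A.size + B.size, 'LoRA parameters vs', W.size, 'in the full matrix')\n",
        "```\n",
        "\n",
        "Because `B` starts at zero, fine tuning begins from exactly the pre-trained layer's behavior, and only `A` and `B` receive gradient updates."
      ],
      "metadata": {}
    },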
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Setting up LoRA using parameter efficient fine tuning\n",
        "\"\"\"\n",
        "\n",
        "from peft import LoraConfig, get_peft_model\n",
        "\n",
        "#defining how LoRA will work in this particular example\n",
        "config = LoraConfig(\n",
        "    r=8,\n",
        "    lora_alpha=8,\n",
        "    target_modules=[\"query_key_value\"],\n",
        "    lora_dropout=0.05,\n",
        "    bias=\"none\",\n",
        "    task_type=\"CAUSAL_LM\"\n",
        ")\n",
        "\n",
        "#this actually modifies the model in memory, so\n",
        "#the rename is only for legibility.\n",
        "peft_model = get_peft_model(model, config)"
      ],
      "metadata": {
        "id": "BXq6vYyb29yQ"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Printing Trainable Parameter Difference"
      ],
      "metadata": {
        "id": "qsSod_or3Fax"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Comparing parameters before and after LoRA\n",
        "\"\"\"\n",
        "\n",
        "trainable_params = 0\n",
        "all_param = 0\n",
        "\n",
        "#iterating over all parameters\n",
        "for _, param in peft_model.named_parameters():\n",
        "    #adding parameters to total\n",
        "    all_param += param.numel()\n",
        "    #adding parameters to trainable if they require a gradient\n",
        "    if param.requires_grad:\n",
        "        trainable_params += param.numel()\n",
        "\n",
        "#printing results\n",
        "print(f\"trainable params: {trainable_params}\")\n",
        "print(f\"all params: {all_param}\")\n",
        "print(f\"trainable: {100 * trainable_params / all_param:.2f}%\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1ep1kImT3KJc",
        "outputId": "6c61169b-8cae-4996-f1a7-f1e9dd93768a"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "trainable params: 786432\n",
            "all params: 560001024\n",
            "trainable: 0.14%\n"
          ]
        }
      ]
    },
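    {
      "cell_type": "markdown",
      "source": [
        "The trainable count printed above can be checked by hand. The sketch below assumes the published `bloom-560m` dimensions (hidden size 1024, 24 transformer layers); these are hard-coded assumptions, not values read from the loaded model.\n",
        "\n",
        "```python\n",
        "# assumed bloom-560m dimensions, not read from the loaded model\n",
        "hidden_size = 1024\n",
        "n_layers = 24\n",
        "r = 8\n",
        "\n",
        "# query_key_value maps hidden_size -> 3 * hidden_size (query, key and value fused)\n",
        "d_in, d_out = hidden_size, 3 * hidden_size\n",
        "\n",
        "# each adapted layer adds A (r x d_in) and B (d_out x r)\n",
        "lora_params_per_layer = r * (d_in + d_out)\n",
        "lora_params = n_layers * lora_params_per_layer\n",
        "\n",
        "all_params = 560001024  # total printed above, LoRA weights included\n",
        "\n",
        "print(lora_params)                               # 786432\n",
        "print(f'{100 * lora_params / all_params:.2f}%')  # 0.14%\n",
        "```\n",
        "\n",
        "Under those assumptions the arithmetic reproduces both the trainable parameter count and the percentage."
      ],
      "metadata": {}
    },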
    {
      "cell_type": "markdown",
      "source": [
        "## Loading Dataset\n",
        "This is the Stanford Question Answering Dataset (SQuAD), which we'll use to fine tune BLOOM to improve performance on question answering. The `squad_v2` version also contains questions that cannot be answered from their context."
      ],
      "metadata": {
        "id": "BCii0BM441AF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Loading SQuAD dataset\n",
        "\"\"\"\n",
        "\n",
        "from datasets import load_dataset\n",
        "qa_dataset = load_dataset(\"squad_v2\")"
      ],
      "metadata": {
        "id": "GPbR9Yk132nU"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Re-Formatting\n",
        "We're going to get the LLM to learn a specific format (a common use of fine tuning).\n",
        "\n",
        "```\n",
        "**CONTEXT:**\n",
        "{context}\n",
        "\n",
        "**QUESTION:**\n",
        "{question}\n",
        "\n",
        "**ANSWER:**\n",
        "{answer}</s>\n",
        "```\n",
        "\n",
        "So, we'll reformat our SQuAD dataset to respect that format."
      ],
      "metadata": {
        "id": "0qqx4lwA5x9q"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Reformatting SQuAD to respect our defined structure\n",
        "\"\"\"\n",
        "\n",
        "#defining a function for reformatting\n",
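        "def create_prompt(context, question, answer):\n",
        "  #squad_v2 includes unanswerable questions, which have an empty answer list\n",
        "  if len(answer[\"text\"]) < 1:\n",
        "    answer = \"Cannot Find Answer\"\n",
        "  else:\n",
        "    answer = answer[\"text\"][0]\n",
        "  #same markers as the documented format and the prompt used at inference time\n",
        "  prompt_template = f\"**CONTEXT:**\\n{context}\\n\\n**QUESTION:**\\n{question}\\n\\n**ANSWER:**\\n{answer}</s>\"\n",
        "  return prompt_template\n",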
        "\n",
        "#applying the reformatting function to the entire dataset\n",
        "mapped_qa_dataset = qa_dataset.map(lambda samples: tokenizer(create_prompt(samples['context'], samples['question'], samples['answers'])))"
      ],
      "metadata": {
        "id": "gpDfILFX6Tud"
      },
      "execution_count": 6,
      "outputs": []
    },
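    {
      "cell_type": "markdown",
      "source": [
        "To see what the model is trained on, here is a small standalone sketch of the target format applied to two made-up records, one answerable and one not. `format_example` is a hypothetical helper written for this illustration, and the records are invented rather than taken from SQuAD.\n",
        "\n",
        "```python\n",
        "def format_example(context, question, answer_texts):\n",
        "    # hypothetical helper: fall back to a fixed string when there is no answer\n",
        "    answer = answer_texts[0] if answer_texts else 'Cannot Find Answer'\n",
        "    return (\n",
        "        f'**CONTEXT:**\\n{context}\\n\\n'\n",
        "        f'**QUESTION:**\\n{question}\\n\\n'\n",
        "        f'**ANSWER:**\\n{answer}</s>'\n",
        "    )\n",
        "\n",
        "# made-up records, not taken from SQuAD\n",
        "answerable = format_example('The sky is blue.', 'What color is the sky?', ['blue'])\n",
        "unanswerable = format_example('The sky is blue.', 'Who painted it?', [])\n",
        "\n",
        "print(answerable)\n",
        "print('---')\n",
        "print(unanswerable)\n",
        "```\n",
        "\n",
        "The trailing `</s>` marks where the answer ends, which is what lets the fine-tuned model stop after answering instead of continuing to generate."
      ],
      "metadata": {}
    },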
    {
      "cell_type": "markdown",
      "source": [
        "## Training our LoRA model on SQuAD\n",
        "Updating the decomposed matrices to improve the model on question answering, and teach it the desired structure."
      ],
      "metadata": {
        "id": "ytep61CF6lQb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Fine Tuning\n",
        "This code is largely co-opted. In the absence of a rigid validation\n",
        "procedure, the best practice is to just copy a successful tutorial or,\n",
        "better yet, directly from the documentation.\n",
        "\"\"\"\n",
        "\n",
        "import transformers\n",
        "\n",
        "trainer = transformers.Trainer(\n",
        "    model=peft_model,\n",
        "    train_dataset=mapped_qa_dataset[\"train\"],\n",
        "    args=transformers.TrainingArguments(\n",
        "        per_device_train_batch_size=4,\n",
        "        gradient_accumulation_steps=4,\n",
        "        warmup_steps=100,\n",
        "        max_steps=100,\n",
        "        learning_rate=1e-3,\n",
        "        fp16=True,\n",
        "        logging_steps=1,\n",
        "        output_dir='outputs',\n",
        "    ),\n",
        "    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
        ")\n",
        "peft_model.config.use_cache = False  # silence the warnings. Please re-enable for inference!\n",
        "trainer.train()"
      ],
      "metadata": {
        "id": "bVcKf0h16wL-",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "ebec74cc-0f13-407a-8c36-d52a0dd39edd"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ],
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='100' max='100' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [100/100 00:52, Epoch 0/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>3.454900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>3.347400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>3.374100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>3.523800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>3.493500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>3.413400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>3.236500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>3.487500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>3.538700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>3.469000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>3.251100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>3.474400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>3.349900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>3.084300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>3.427200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>3.301600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>3.233200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>3.479600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>3.304600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>3.118300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>21</td>\n",
              "      <td>3.333300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>22</td>\n",
              "      <td>3.381700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>23</td>\n",
              "      <td>3.145900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>24</td>\n",
              "      <td>3.299600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>25</td>\n",
              "      <td>3.192000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>26</td>\n",
              "      <td>3.188100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>27</td>\n",
              "      <td>3.238600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>28</td>\n",
              "      <td>2.978600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>29</td>\n",
              "      <td>3.104600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>3.153400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>31</td>\n",
              "      <td>3.090600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>32</td>\n",
              "      <td>3.171500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>33</td>\n",
              "      <td>3.197400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>34</td>\n",
              "      <td>2.906000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>35</td>\n",
              "      <td>3.080200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>36</td>\n",
              "      <td>3.001500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>37</td>\n",
              "      <td>3.040000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>38</td>\n",
              "      <td>3.063700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>39</td>\n",
              "      <td>3.070500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>40</td>\n",
              "      <td>3.117000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>41</td>\n",
              "      <td>3.053800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>42</td>\n",
              "      <td>3.031400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>43</td>\n",
              "      <td>2.908200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>44</td>\n",
              "      <td>3.042400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>45</td>\n",
              "      <td>2.925400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>46</td>\n",
              "      <td>3.022900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>47</td>\n",
              "      <td>3.061900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>48</td>\n",
              "      <td>2.868400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>49</td>\n",
              "      <td>2.894600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>50</td>\n",
              "      <td>2.938900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>51</td>\n",
              "      <td>3.037900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>52</td>\n",
              "      <td>2.849900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>53</td>\n",
              "      <td>2.959000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>54</td>\n",
              "      <td>2.837100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>55</td>\n",
              "      <td>3.030200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>56</td>\n",
              "      <td>2.755000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>57</td>\n",
              "      <td>2.951700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>58</td>\n",
              "      <td>2.898800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>59</td>\n",
              "      <td>2.913300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>60</td>\n",
              "      <td>2.935700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>61</td>\n",
              "      <td>3.061400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>62</td>\n",
              "      <td>2.992700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>63</td>\n",
              "      <td>2.877900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>64</td>\n",
              "      <td>3.002000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>65</td>\n",
              "      <td>2.911300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>66</td>\n",
              "      <td>2.866500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>67</td>\n",
              "      <td>3.074400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>68</td>\n",
              "      <td>3.065500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>69</td>\n",
              "      <td>3.021000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>70</td>\n",
              "      <td>2.816200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>71</td>\n",
              "      <td>3.010800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>72</td>\n",
              "      <td>2.708400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>73</td>\n",
              "      <td>2.871200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>74</td>\n",
              "      <td>2.838100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>75</td>\n",
              "      <td>2.954700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>76</td>\n",
              "      <td>2.916800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>77</td>\n",
              "      <td>2.926200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>78</td>\n",
              "      <td>2.829200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>79</td>\n",
              "      <td>2.743400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>80</td>\n",
              "      <td>2.854500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>81</td>\n",
              "      <td>2.858800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>82</td>\n",
              "      <td>2.924200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>83</td>\n",
              "      <td>2.927600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>84</td>\n",
              "      <td>2.863300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>85</td>\n",
              "      <td>3.074700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>86</td>\n",
              "      <td>2.734500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>87</td>\n",
              "      <td>2.845900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>88</td>\n",
              "      <td>3.171200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>89</td>\n",
              "      <td>2.971300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>90</td>\n",
              "      <td>3.006000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>91</td>\n",
              "      <td>2.909100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>92</td>\n",
              "      <td>2.802900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>93</td>\n",
              "      <td>2.794200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>94</td>\n",
              "      <td>2.946800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>95</td>\n",
              "      <td>2.895400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>96</td>\n",
              "      <td>2.656600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>97</td>\n",
              "      <td>2.869700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>98</td>\n",
              "      <td>2.698000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>99</td>\n",
              "      <td>3.025400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>100</td>\n",
              "      <td>3.083400</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ]
          },
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "TrainOutput(global_step=100, training_loss=3.054349970817566, metrics={'train_runtime': 53.5128, 'train_samples_per_second': 29.899, 'train_steps_per_second': 1.869, 'total_flos': 740744642985984.0, 'train_loss': 3.054349970817566, 'epoch': 0.01})"
            ]
          },
          "metadata": {},
          "execution_count": 7
        }
      ]
    },
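    {
      "cell_type": "markdown",
      "source": [
        "The `epoch: 0.01` in the training output follows from the batch settings. The sketch below works through the arithmetic, assuming a single GPU and a `squad_v2` train split of 130,319 examples; both values are hard-coded assumptions rather than read from the trainer.\n",
        "\n",
        "```python\n",
        "per_device_train_batch_size = 4\n",
        "gradient_accumulation_steps = 4\n",
        "max_steps = 100\n",
        "n_devices = 1            # assumption: a single GPU\n",
        "train_examples = 130319  # assumption: size of the squad_v2 train split\n",
        "\n",
        "# samples consumed per optimizer step\n",
        "samples_per_step = per_device_train_batch_size * gradient_accumulation_steps * n_devices\n",
        "samples_seen = samples_per_step * max_steps\n",
        "epochs = samples_seen / train_examples\n",
        "\n",
        "print(samples_per_step)  # 16\n",
        "print(samples_seen)      # 1600\n",
        "print(f'{epochs:.2f}')   # 0.01\n",
        "```\n",
        "\n",
        "So this run only sees about 1,600 examples, which is consistent with the reported `train_samples_per_second` multiplied by `train_runtime`. Note also that with `warmup_steps` equal to `max_steps`, the learning rate is still warming up for the whole run."
      ],
      "metadata": {}
    },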
    {
      "cell_type": "markdown",
      "source": [
        "## Saving Locally\n",
        "Saving our LoRA fine-tuning results."
      ],
      "metadata": {
        "id": "7ki4RGWm7UGt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Saving the LoRA fine tuning locally\n",
        "\"\"\"\n",
        "model_id = \"BLOOM-560m-LoRA\"\n",
        "peft_model.save_pretrained(model_id)"
      ],
      "metadata": {
        "id": "I-tQ5ng27ts0"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Checking File Size\n",
        "Compare this to the size of the initial model download to get an idea of the storage savings."
      ],
      "metadata": {
        "id": "eT2bM7ou76Ld"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!ls -lh {model_id}"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "I-pnuRL48F1S",
        "outputId": "cd9f8b06-4c3f-4669-eecb-5751f1e25820"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "total 3.1M\n",
            "-rw-r--r-- 1 root root  482 Nov  6 14:17 adapter_config.json\n",
            "-rw-r--r-- 1 root root 3.1M Nov  6 14:17 adapter_model.bin\n",
            "-rw-r--r-- 1 root root 5.3K Nov  6 14:17 README.md\n"
          ]
        }
      ]
    },
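    {
      "cell_type": "markdown",
      "source": [
        "As a rough sanity check on the listing above, the adapter size can be estimated from the parameter counts. This sketch assumes the adapter weights are stored as float32 (4 bytes each) and the base model is held in float16 (2 bytes each); neither is read from the saved files.\n",
        "\n",
        "```python\n",
        "lora_params = 786432      # trainable parameter count printed earlier\n",
        "all_params = 560001024    # total parameter count printed earlier\n",
        "base_params = all_params - lora_params\n",
        "\n",
        "adapter_bytes = lora_params * 4  # assumption: float32 adapter weights\n",
        "base_bytes = base_params * 2     # assumption: float16 base weights\n",
        "\n",
        "print(f'{adapter_bytes / 2**20:.1f} MiB')  # 3.0 MiB\n",
        "print(f'{base_bytes / 2**30:.2f} GiB')     # 1.04 GiB\n",
        "```\n",
        "\n",
        "The 3.1M reported for `adapter_model.bin` is in line with that estimate once the file's serialization overhead is included."
      ],
      "metadata": {}
    },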
    {
      "cell_type": "markdown",
      "source": [
        "## Testing"
      ],
      "metadata": {
        "id": "uGBiNfNbm98k"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "\"\"\"Helper Function for Comparing Results\n",
        "\"\"\"\n",
        "\n",
        "from IPython.display import display, Markdown\n",
        "\n",
        "def make_inference(context, question):\n",
        "\n",
        "    #turn the input into tokens\n",
        "    batch = tokenizer(f\"**CONTEXT:**\\n{context}\\n\\n**QUESTION:**\\n{question}\\n\\n**ANSWER:**\\n\", return_tensors='pt', return_token_type_ids=False)\n",
        "    #move the tokens onto the GPU, for inference\n",
        "    batch = batch.to(device='cuda')\n",
        "\n",
        "    #make an inference with both the fine tuned model and the raw model\n",
        "    with torch.cuda.amp.autocast():\n",
        "        #I think inference time would be faster if the LoRA weights were\n",
        "        #merged into the model, but keeping them separate allows me to\n",
        "        #compare before and after fine tuning simultaneously\n",
        "\n",
        "        #raw model\n",
        "        peft_model.disable_adapter_layers()\n",
        "        output_tokens_raw = model.generate(**batch, max_new_tokens=200)\n",
        "\n",
        "        #LoRA model\n",
        "        peft_model.enable_adapter_layers()\n",
        "        output_tokens_qa = peft_model.generate(**batch, max_new_tokens=200)\n",
        "\n",
        "    #display results\n",
        "    display(Markdown(\"# Raw Model\\n\"))\n",
        "    display(Markdown((tokenizer.decode(output_tokens_raw[0], skip_special_tokens=True))))\n",
        "    display(Markdown(\"\\n# QA Model\\n\"))\n",
        "    display(Markdown((tokenizer.decode(output_tokens_qa[0], skip_special_tokens=True))))"
      ],
      "metadata": {
        "id": "vSZZjyp1mPoJ"
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "context = \"You are a monster, and you eat yellow legos.\"\n",
        "question = \"What is the best food?\"\n",
        "\n",
        "make_inference(context, question)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 504
        },
        "id": "qbjfkm-4om-G",
        "outputId": "5db3c6f0-e070-4cfd-e5a0-d3b238e4137a"
      },
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "# Raw Model\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "**CONTEXT:**\nYou are a monster, and you eat yellow legos.\n\n**QUESTION:**\nWhat is the best food?\n\n**ANSWER:**\nThe best food is the one that is not poisonous, but that is\nnot poisonous at all.\n\n**QUESTION:**\nWhat is the best food?\n\n**ANSWER:**\nThe best food is the one that is not poisonous, but that is\nnot poisonous at all.\n\n**QUESTION:**\nWhat is the best food?\n\n**ANSWER:**\nThe best food is the one that is not poisonous, but that is\nnot poisonous at all.\n\n**QUESTION:**\nWhat is the best food?\n\n**ANSWER:**\nThe best food is the one that is not poisonous, but that is\nnot poisonous at all.\n\n**QUESTION:**\nWhat is the best food?\n\n**ANSWER:**\nThe best food is the one that is not poisonous, but that is\nnot poisonous at all.\n\n**QUESTION:**\nWhat is the best food?\n\n**ANSWER:**\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "\n# QA Model\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "**CONTEXT:**\nYou are a monster, and you eat yellow legos.\n\n**QUESTION:**\nWhat is the best food?\n\n**ANSWER:**\nyellow legos"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "context = \"you are a math wizard\"\n",
        "question = \"what is 1+1 equal to?\"\n",
        "\n",
        "make_inference(context, question)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 597
        },
        "id": "OrFE2IvG_B1I",
        "outputId": "a680fe51-2bd5-40ea-8d51-36c931b92456"
      },
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "# Raw Model\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "**CONTEXT:**\nyou are a math wizard\n\n**QUESTION:**\nwhat is 1+1 equal to?\n\n**ANSWER:**\n1+1 is equal to 1.0\n\n**QUESTION:**\nwhat is 1+1 equal to?\n\n**ANSWER:**\n1+1 is equal to 1.0\n\n**QUESTION:**\nwhat is 1+1 equal to?\n\n**ANSWER:**\n1+1 is equal to 1.0\n\n**QUESTION:**\nwhat is 1+1 equal to?\n\n**ANSWER:**\n1+1 is equal to 1.0\n\n**QUESTION:**\nwhat is 1+1 equal to?\n\n**ANSWER:**\n1+1 is equal to 1.0\n\n**QUESTION:**\nwhat is 1+1 equal to?\n\n**ANSWER:**\n1+1 is equal to 1.0\n\n**QUESTION:**\nwhat is 1+1 equal to?\n\n**ANSWER:**\n1+1 is equal to 1.0\n\n**QUESTION:**\nwhat is 1+1 equal to?\n\n**ANSWER:**\n1+1 is equal"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "\n# QA Model\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "**CONTEXT:**\nyou are a math wizard\n\n**QUESTION:**\nwhat is 1+1 equal to?\n\n**ANSWER:**\n1+1 = 1"
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "context = \"Answer the riddle\"\n",
        "question = \"What gets bigger the more you take away?\"\n",
        "\n",
        "make_inference(context, question)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 270
        },
        "id": "XNKu-ugS_0wV",
        "outputId": "7c3c94ea-b7ba-4545-d1f9-0e4f68f6cf78"
      },
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "# Raw Model\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "**CONTEXT:**\nAnswer the riddle\n\n**QUESTION:**\nWhat gets bigger the more you take away?\n\n**ANSWER:**\nThe answer is that the more you take away, the more you get away."
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "\n# QA Model\n"
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "<IPython.core.display.Markdown object>"
            ],
            "text/markdown": "**CONTEXT:**\nAnswer the riddle\n\n**QUESTION:**\nWhat gets bigger the more you take away?\n\n**ANSWER:**\nCannot Find Answer"
          },
          "metadata": {}
        }
      ]
    }
  ]
}